import torch
import torch.nn as nn


class AngularPenaltySMLoss(nn.Module):
    """Margin-penalised softmax loss: CosFace, ArcFace or SphereFace.

    For the target class ``y`` of each row, the scaled logit ``s * cos`` is
    replaced by a margin-adjusted value:

    * cosface:    ``s * (cos - m)``          (additive cosine margin)
    * arcface:    ``s * cos(theta + m)``     (additive angular margin)
    * sphereface: ``s * cos(m * theta)``     (multiplicative angular margin)

    Here ``theta = acos(cos)``. Every non-target class keeps ``s * wf``. The
    loss is the batch mean of ``-log softmax`` at the target class.

    ``forward`` treats ``wf`` as cosine similarities for the margin variants:
    the arcface/sphereface branches clamp the target entry into (-1, 1)
    before ``acos``. Presumably the caller L2-normalises both features and
    class weights; TODO confirm at the call site.

    Args:
        loss_type: ``"arcface"``, ``"sphereface"``, ``"cosface"`` or
            ``"crossentropy"``. ``"crossentropy"`` applies a plain
            ``nn.CrossEntropyLoss`` to ``wf`` and ignores ``s`` and ``m``.
        eps: Distance kept from +/-1 when clamping before ``acos``. This
            keeps ``acos`` (and its gradient) finite.
        s: Logit scale. A falsy value (0/None) selects the per-loss default
            (64.0 arcface, 64.0 sphereface, 20.0 cosface).
        m: Margin. A falsy value selects the per-loss default
            (0.5 arcface, 1.35 sphereface, 0.0 cosface).

    Raises:
        ValueError: If ``loss_type`` is not one of the supported names.
    """

    def __init__(self, loss_type="cosface", eps=1e-7, s=20, m=0.0):
        super().__init__()
        supported = ("arcface", "sphereface", "cosface", "crossentropy")
        if loss_type not in supported:
            # Explicit raise instead of `assert`, which is stripped under `python -O`.
            raise ValueError(
                f"Invalid loss type {loss_type!r}; expected one of {supported}"
            )

        # (default scale, default margin), used when the caller passes a falsy s / m.
        margin_defaults = {
            "arcface": (64.0, 0.5),
            "sphereface": (64.0, 1.35),
            "cosface": (20.0, 0.0),
        }
        if loss_type in margin_defaults:
            default_s, default_m = margin_defaults[loss_type]
            self.s = s if s else default_s
            self.m = m if m else default_m
        else:
            # "crossentropy" never reads s / m. Store them anyway so the
            # attributes exist for every loss_type.
            self.s = s
            self.m = m

        self.loss_type = loss_type
        self.eps = eps
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, wf, labels):
        """Return the batch-mean loss.

        Args:
            wf: 2-D logits, one row per sample and one column per class.
                These are cosine similarities for the margin variants.
            labels: Integer class index for each row of ``wf`` (a tensor or
                a sequence convertible with ``torch.as_tensor``).

        Returns:
            A scalar tensor: the mean over rows of ``-log p(target)``, where
            ``p`` is the softmax over the margin-adjusted scaled logits.

        Raises:
            ValueError: If ``self.loss_type`` was changed to an unsupported
                value after construction.
        """
        if self.loss_type == "crossentropy":
            return self.cross_entropy(wf, labels)

        labels = torch.as_tensor(labels, device=wf.device)
        rows = torch.arange(wf.size(0), device=wf.device)
        # Target-class logit wf[i, labels[i]]. Direct indexing is O(batch),
        # whereas diagonal(wf.T[labels]) builds a (batch, batch) matrix.
        target = wf[rows, labels]

        if self.loss_type == "cosface":
            numerator = self.s * (target - self.m)
        elif self.loss_type in ("arcface", "sphereface"):
            # Clamp strictly inside (-1, 1): acos' derivative is infinite at +/-1.
            theta = torch.acos(torch.clamp(target, -1.0 + self.eps, 1 - self.eps))
            if self.loss_type == "arcface":
                # NOTE: cos(theta + m) is used as-is. There is no special
                # handling for theta + m > pi, where it stops decreasing.
                numerator = self.s * torch.cos(theta + self.m)
            else:
                numerator = self.s * torch.cos(self.m * theta)
        else:
            raise ValueError("Invalid loss type")

        # Put the margin-adjusted numerator in each row's target column. All
        # other classes keep the plain scaled logit s * wf.
        is_target = (
            torch.arange(wf.size(1), device=wf.device) == labels.unsqueeze(1)
        )
        logits = torch.where(is_target, numerator.unsqueeze(1), self.s * wf)
        # Per row: L = num - log(exp(num) + sum_{j != y} exp(s * wf_j)).
        # logsumexp gives the same value without overflowing exp() for large s.
        log_likelihood = numerator - torch.logsumexp(logits, dim=1)
        return -torch.mean(log_likelihood)
